[Transform] Spinquant with R1 and R2 #1615

Merged: 53 commits into main on Aug 13, 2025

Conversation

brian-dellabetta
Collaborator

@brian-dellabetta brian-dellabetta commented Jul 2, 2025

Purpose

  • Enable offline spinquant-style transforms
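
SpinQuant-style rotations can be folded into the weights ahead of time because, for an orthogonal rotation R, rotating the weights and counter-rotating the incoming activations cancel out. A minimal sketch of this invariance (toy shapes and hypothetical names, not the modifier's actual code):

import torch

# Toy illustration: an orthogonal rotation fused into a weight matrix offline
# cancels against rotated incoming activations, so layer outputs are preserved.
hidden = 8
torch.manual_seed(0)
W = torch.randn(16, hidden)                           # linear weight (out x in)
x = torch.randn(hidden)                               # incoming activation
R, _ = torch.linalg.qr(torch.randn(hidden, hidden))   # random orthogonal matrix

W_rot = W @ R        # rotation folded into the weight offline
x_rot = R.T @ x      # previous layer emits rotated activations
assert torch.allclose(W @ x, W_rot @ x_rot, atol=1e-4)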

Prerequisites

  • neuralmagic/compressed-tensors#391

Changes

  • Added spinquant_example.py to examples folder
  • Added SpinQuantModifier which handles the construction of a spinquant-style transform config
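
A minimal recipe using the new modifier, mirroring the example script in the Evaluation section below (model and scheme shown are illustrative):

from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype="auto"
)

# Fold R1/R2 Hadamard rotations into the weights, then quantize to W4A16,
# all without calibration data ("datafree" pipeline).
recipe = [
    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]
oneshot(model=model, recipe=recipe, pipeline="datafree")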

Testing

  • Added modifier serialization and correctness tests

Evaluation

Using this branch and the original SpinQuant code, we see very similar results for meta-llama/Llama-3.2-1B-Instruct with W4A16 quantization. Results are equivalent in HF (in-memory vs. serialized and re-loaded), and very similar in vLLM. The symmetric scales calculation in llm-compressor differs slightly from the original SpinQuant paper, which uses the original GPTQ implementation. When that calculation is swapped in, results are consistent, with the Hadamard rotations improving results on gsm8k_llama and arc_challenge_llama:

Scheme           Impl          gsm8k    gsm8k_llama  arc_challenge_llama
Hadamard+W4A16   LC            0.2403   0.2835       0.5262
W4A16            LC            0.1964   0.1933       0.4781
Hadamard+W4A16   LC+SQscales   0.1721   0.2183       0.485
W4A16            LC+SQscales   0.207    0.1706       0.4498
Hadamard+W4A16   SQ            0.1736   0.2282       0.4807
W4A16            SQ            0.1986   0.1774       0.4489

To run LC+SQscales, change this line in compressed-tensors (CT) from

scales = max_val_pos / (float(bit_range) / 2)

to

scales = max_val_pos / (float(bit_max))
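
For W4A16 this changes the divisor from 7.5 (half of the full int4 range) to 7 (the positive int4 bound). A standalone sketch of the two variants, assuming the usual signed int4 range of [-8, 7] (illustrative only, not the compressed-tensors source):

import torch

def symmetric_scales(w: torch.Tensor, num_bits: int = 4, sq_style: bool = False):
    # signed integer range, e.g. int4 -> [-8, 7]
    bit_min = -(2 ** (num_bits - 1))
    bit_max = 2 ** (num_bits - 1) - 1
    bit_range = bit_max - bit_min                     # 15 for int4
    max_val_pos = w.abs().amax(dim=-1, keepdim=True)
    if sq_style:
        return max_val_pos / float(bit_max)           # SpinQuant/GPTQ: max / 7
    return max_val_pos / (float(bit_range) / 2)       # llm-compressor: max / 7.5
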
The following Python script was used to generate these results.

Clone the SpinQuant repo and paste this script into its top-level directory:

# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from typing import Literal
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from torch import nn
import lm_eval

from transformers import LlamaForCausalLM, AutoTokenizer
import transformers
from train_utils.main import prepare_model
from train_utils.modeling_llama_quant import LlamaForCausalLM as LlamaForCausalLMQuant
from utils.hadamard_utils import random_hadamard_matrix, hadamard_matrix
from utils.process_args import process_args_ptq

# model_id = "meta-llama/Llama-3.1-8B-Instruct"
# model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "meta-llama/Llama-3.2-1B-Instruct"
dtype = torch.bfloat16


class RotateModule(nn.Module):
    def __init__(self, R_init):
        super(RotateModule, self).__init__()
        self.weight = nn.Parameter(R_init.to(torch.float32).to(torch.device("cuda")))

    def forward(self, x, transpose=False):
        if transpose:
            return x @ self.weight
        else:
            return self.weight @ x


def get_sq_model(
    r1r2: Literal["eye", "random-hadamard", "hadamard"],
    w_bits: Literal[4, 16],
    w_clip: bool = False,
) -> LlamaForCausalLMQuant:
    model_args, training_args, ptq_args = process_args_ptq()
    model_args.input_model = model_id
    if w_bits == 4:
        ptq_args.w_bits = 4
        ptq_args.w_groupsize = 128
        ptq_args.w_rtn = True  # if False, GPTQ is used
        ptq_args.w_clip = w_clip
    ptq_args.a_bits = 16
    ptq_args.k_bits = 16
    ptq_args.v_bits = 16

    print("=======ARGS=======", ptq_args)

    config = transformers.AutoConfig.from_pretrained(model_args.input_model)

    # Llama 3.2 specific: SpinQuant is not compatible with tie_word_embeddings,
    # so clone lm_head from embed_tokens
    process_word_embeddings = False
    if config.tie_word_embeddings:
        config.tie_word_embeddings = False
        process_word_embeddings = True

    model = LlamaForCausalLMQuant.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        config=config,
        torch_dtype=dtype,
        device_map="cuda",
    )

    if process_word_embeddings:
        model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

    model = prepare_model(ptq_args, model)
    for param in model.parameters():
        param.requires_grad = False
    match r1r2:
        case "eye":
            R1 = torch.eye(model.config.hidden_size, device="cuda")
        case "random-hadamard":
            R1 = random_hadamard_matrix(model.config.hidden_size, "cuda")
        case _:
            R1 = hadamard_matrix(model.config.hidden_size, "cuda")
    model.R1 = RotateModule(R1)
    for i in range(model.config.num_hidden_layers):
        # Each head dim = 128 for Llama model
        match r1r2:
            case "eye":
                R2 = torch.eye(
                    model.config.hidden_size // model.config.num_attention_heads,
                    device="cuda",
                )
            case "random-hadamard":
                R2 = random_hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads, "cuda"
                )
            case _:
                R2 = hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads, "cuda"
                )
        model.model.layers[i].self_attn.R2 = RotateModule(R2)

    model.config.use_cache = False

    return model


def get_lc_model(
    r1r2: Literal["eye", "random-hadamard", "hadamard"], w_bits: Literal[4, 16]
) -> LlamaForCausalLM:
    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier
    from llmcompressor.modifiers.transform import SpinQuantModifier

    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_id,
        torch_dtype=dtype,
        device_map="cuda",
    )

    recipe = [
        SpinQuantModifier(
            rotations=[] if r1r2 == "eye" else ["R1", "R2"],
            transform_type="hadamard",
        )
    ]
    if w_bits == 4:
        recipe.append(
            QuantizationModifier(
                targets="Linear",
                scheme="W4A16",
                ignore=["lm_head"],
            )
        )

    oneshot(
        model=model,
        recipe=recipe,
        pipeline="datafree",
        log_dir=None,
    )

    return model


if __name__ == "__main__":
    for scales_impl in ["sq_min_hack", "lc_min_hack"]:
        for r1r2 in ["eye", "hadamard"]:
            for sq_lc in ["sq", "lc"]:
                w_bits = 4

                os.environ["SCALES_IMPL"] = scales_impl

                model = (
                    get_sq_model(r1r2=r1r2, w_bits=w_bits)
                    if sq_lc == "sq"
                    else get_lc_model(r1r2=r1r2, w_bits=w_bits)
                ).to("cuda")

                SAVE_DIR = model_id.split("/")[1] + f"-{scales_impl}-{r1r2}-w4a16"
                model.save_pretrained(SAVE_DIR, save_compressed=True)
                tokenizer = AutoTokenizer.from_pretrained(
                    model_id, trust_remote_code=True
                )
                tokenizer.save_pretrained(SAVE_DIR)

                del model
                del tokenizer
                torch.cuda.empty_cache()

                results = lm_eval.simple_evaluate(
                    # 1) hf in-memory
                    # model=lm_eval.models.huggingface.HFLM(
                    #     pretrained=model,
                    #     batch_size=32,
                    #     add_bos_token=False,
                    # ),
                    # 1/)
                    # 2) vllm serialized
                    model="vllm",
                    model_args={
                        "pretrained": SAVE_DIR,
                        "add_bos_token": False,
                        "dtype": "auto",
                        "max_model_len": 4096,
                        "gpu_memory_utilization": 0.5,
                        "enable_chunked_prefill": True,
                    },
                    # 2/)
                    # 3) hf serialized
                    # model="hf",
                    # model_args={
                    #     "pretrained": SAVE_DIR,
                    #     "add_bos_token": False,
                    #     "dtype": "auto",
                    # },
                    # device="cuda",
                    # 3/)
                    tasks=["gsm8k_llama", "gsm8k", "arc_challenge_llama"],
                    num_fewshot=8,
                    batch_size=32,
                    apply_chat_template=True,
                    fewshot_as_multiturn=True,
                )
                print(
                    f"RESULTS, {model_id} {sq_lc} R1R2 {r1r2} W_BITS {w_bits} SCALEIMPL {scales_impl}"
                )
                print(lm_eval.utils.make_table(results))

Follow Ups

  • Infer the data-free pipeline, even if a transform modifier is included
  • Rotations R3 and R4
  • Modify example to use GPTQ once basic evaluation has been performed

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @brian-dellabetta, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the TransformModifier by introducing support for predefined transformation configurations, known as presets. This allows users to easily apply complex transformation schemes like QUIP and SpinQuant, streamlining the process of applying advanced model compression techniques. The changes also include an updated example demonstrating the new functionality and improved validation for the modifier.

Highlights

  • Enhanced TransformModifier Flexibility: The TransformModifier now accepts either a preset_config string to load predefined transformation schemes (like QUIP or SpinQuant) or a direct config object for custom transformation setups, making it more versatile and user-friendly.
  • Introduction of Predefined Transformation Presets: New modules have been added under src/llmcompressor/modifiers/transform/presets to define and expose QUIP, QUIP_ONLINE, LLAMA_SPINQUANT, and LLAMA_SPINQUANT_R1R2 configurations. These presets simplify the application of complex transformation strategies based on research papers.
  • Updated Llama-3 Example: The llama3_example.py script has been revised to showcase the usage of the TransformModifier with a preset_config (specifically LLAMA_SPINQUANT_R1R2) and to use QuantizationModifier instead of GPTQModifier. The example also now uses a smaller Llama model for faster execution and includes a dispatch_for_generation call.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request significantly enhances the TransformModifier by introducing a robust preset configuration system and improving module targeting. The refactoring to use Pydantic for configuration validation greatly improves maintainability and prevents invalid states. The changes to use regex for module targeting in the presets (spinquant.py and quip.py) are a notable improvement for flexibility and robustness.

brian-dellabetta and others added 8 commits July 8, 2025 21:29
@kylesayrs kylesayrs changed the base branch from kylesayrs/transform-modifier to main July 11, 2025 18:52
kylesayrs added 11 commits July 11, 2025 14:58
kylesayrs and others added 5 commits July 16, 2025 11:13
Collaborator Author

@brian-dellabetta brian-dellabetta left a comment


This is good to go once neuralmagic/compressed-tensors#391 is in

@brian-dellabetta brian-dellabetta force-pushed the bdellabe/transform-modifier branch from 11662d7 to d0e5bc5 Compare July 24, 2025 21:46
brian-dellabetta and others added 6 commits August 7, 2025 13:59
@kylesayrs kylesayrs marked this pull request as ready for review August 8, 2025 04:43
kylesayrs and others added 2 commits August 8, 2025 00:49
@brian-dellabetta brian-dellabetta added the ready When a PR is ready for review label Aug 12, 2025
Collaborator

@kylesayrs kylesayrs left a comment


Thanks for adding the extra docstring, looks good

Collaborator

@rahul-tuli rahul-tuli left a comment


Very clean! LGTM!

@brian-dellabetta brian-dellabetta merged commit 8747bae into main Aug 13, 2025
13 checks passed
@brian-dellabetta brian-dellabetta deleted the bdellabe/transform-modifier branch August 13, 2025 15:03